#define ARMA_DONT_PRINT_ERRORS

#include <RcppArmadillo.h>
#include <gsl/gsl_rng.h>
#include <gsl/gsl_randist.h>
#include <gsl/gsl_cdf.h>
// [[Rcpp::depends(RcppArmadillo)]]
// [[Rcpp::depends(RcppGSL)]]

#include <cmath>
#include <exception>
#include <limits>
#include <memory>
#include <string>
#include <utility>
#include <vector>



#include <Ziggurat.h>
// [[Rcpp::depends(RcppZiggurat)]]

// Single file-level Ziggurat normal generator; truncn() and get1TN() draw from it
// via zigg.norm().
// NOTE(review): this is one shared object whose norm() presumably advances internal
// state, so concurrent calls from several threads (e.g. inside an OpenMP region)
// would race -- confirm against RcppZiggurat before calling it from parallel code.
static Ziggurat::Ziggurat::Ziggurat zigg;

//[[Rcpp:plugins(cpp11)]]
#include <omp.h>
//[[Rcpp:plugins(openmp)]]


// Natural logarithm of 2 * pi (about 1.8378770664).
const double log2pi = std::log(M_PI * 2.0);



// Draw one value from N(mu, sigma^2) truncated to [bound, Inf) when lb is
// true, or to (-Inf, bound] when lb is false.
//
// Works on the standard-normal scale, sampling z >= cutoff (Geweke 1991):
//  * cutoff < 0.45: plain rejection using Ziggurat normal draws;
//  * otherwise: propose e from an exponential with rate `cutoff` (inversion of
//    an R uniform), accept with probability exp(-e^2 / 2), and use z = e + cutoff.
// Randomness comes from the file-level Ziggurat object and R's uniform RNG.
double truncn(double bound, bool lb, double mu, double sigma){
  // Upper truncation is mirrored so both cases sample z >= cutoff.
  const double cutoff = lb ? (bound - mu) / sigma : -(bound - mu) / sigma;
  double z;

  if (cutoff < 0.45) {
    do {
      z = zigg.norm();
    } while (z < cutoff);
  } else {
    double e;
    double u;
    do {
      e = -log(1 - ::Rf_runif(0.0, 1.0)) / cutoff;
      u = ::Rf_runif(0.0, 1.0);
    } while (u > exp(-0.5 * pow(e, 2)));
    z = e + cutoff;
  }

  // Undo the standardisation (and the mirroring for upper truncation).
  return lb ? mu + sigma * z : mu - sigma * z;
}



 double get1TN(double mu,
 double sd,
 double low,
 double high
 ) {
 double draw = 0.0 ;
int valid = 0 ;
while (valid==0) {
double cand = mu+zigg.norm()*sd;
if ((cand >= low) &
 (cand <= high)
 ) {
 draw = cand ;
 valid = 1 ;
 }
 }
 return(draw) ;
}


// Sum of squared elements of rhs, i.e. x'x when rhs is a column vector x.
//
// Element 0 of rhs.t() * rhs is only the sum of squares for a column vector.
// For a row vector (such as alpha.row(l), used for the variance updates) that
// product is the K x K outer product, whose first element is rhs(0)^2 alone.
// accu(square(.)) yields x'x for column vectors and the full sum of squares
// for row vectors alike.
double vm_mult(const arma::mat& rhs)
{
  return arma::accu(arma::square(rhs));
}

// Return the first element of rhs as a plain double. In this file it turns
// 1 x 1 Armadillo results (a column sum or an inner product) into a scalar.
double vm_convert(const arma::mat& rhs)
{
  const double first = rhs(0);
  return first;
}







// [[Rcpp::export()]]
Rcpp::List mcmc_probit (arma::mat Y,
arma::mat I,
arma::vec tablecountry,
arma::colvec Cinit_conv,
arma::colvec Cinit_unconv,
arma::mat alphastart,
arma::mat alphastart_conv,
arma::mat alphastart_unconv,
int mcmc = 100,
int burn=10,
int thin=10,
int chains=2
 ) {

 // Dimensions
int N = Y.n_rows ;
int nitems = Y.n_cols;
int ncountry=I.n_cols;


 // Current Containers
int N1_conv=round(N/2);
int N1_unconv=round(N/2);
arma::mat P_conv(2,1);
P_conv.fill(0.5) ;
arma::mat P_unconv(2,1);
P_unconv.fill(0.5) ;
arma::mat ystar(N, nitems) ;
ystar.fill(0.0) ;
 arma::mat mu(N, nitems) ;
 arma::colvec C_conv = Cinit_conv ;
 arma::colvec C_unconv = Cinit_unconv ;
 arma::mat alpha = alphastart ;
 arma::mat alpha_conv = alphastart_conv ;
 arma::mat alpha_unconv = alphastart_unconv ;
arma::mat  probC_conv(N, 1) ;
 arma::mat  probC_unconv(N, 1) ;
 arma::mat  pnum(N, 2) ;
 arma::mat  pnum_unconv(N, 2) ;
arma::mat  p(N, 1) ;
arma::mat  p_unconv(N, 1) ;
 arma::vec sigma2_alphaintercept(nitems);
 sigma2_alphaintercept.fill(1.0) ;

 arma::vec sigma2_alphaconv(nitems);
 sigma2_alphaconv.fill(1.0) ;

 arma::vec sigma2_alphaunconv(nitems);
 sigma2_alphaunconv.fill(1.0) ;

arma::mat y(N,1);
y.fill(0.0);


 int lastit= (mcmc-burn)/thin	;

// Trace Containers
  arma::cube tracealpha(nitems*ncountry, lastit,chains) ;
  arma::cube tracealphaconv(nitems*ncountry, lastit,chains) ;
  arma::cube tracealphaunconv(nitems*ncountry, lastit,chains) ;
  arma::cube tracesigma2alpha(nitems,lastit,chains);
  arma::cube tracesigma2alphaconv(nitems,lastit,chains);
  arma::cube tracesigma2alphaunconv(nitems,lastit,chains);
  arma::cube tracePunconv(2, lastit,chains) ;
  arma::cube tracePconv(2, lastit,chains) ;


omp_set_num_threads(chains);

#pragma omp parallel for 
for (int chain = 0; chain < chains; ++chain) {

gsl_rng *s = gsl_rng_alloc(gsl_rng_mt19937);


 // START SAMPLING
 for (int iter = 0 ; iter < mcmc ; iter++) {
// Update C_conv
P_conv(0)= gsl_ran_beta(s, N1_conv+4, N-N1_conv+4);
P_conv(1)=1-P_conv(0);

for (int k = 0 ; k < 2 ; k++) {
arma::mat lindividualitem(N,nitems);
for (int l=0; l<nitems; l++) {
arma::mat mutmp=I*alpha.row(l).t()+(I*alpha_conv.row(l).t())*k+(I*alpha_unconv.row(l).t())%(C_unconv-1);
for (int n = 0 ; n < N ; n++) {
lindividualitem(n,l)=pow(gsl_cdf_ugaussian_P(mutmp(n)), Y(n,l))* pow(gsl_cdf_ugaussian_P(-mutmp(n)), 1-Y(n,l));
}
}
pnum.col(k)=prod(lindividualitem,1)*P_conv(k);
}

p=sum(pnum, 1);
probC_conv=pnum.col(1)/p;

for (int n=0; n<N; n++){
C_conv(n)= gsl_ran_binomial(s,probC_conv(n), 1)+1;
}

N1_conv=N-vm_convert(sum(C_conv-1,0));




// Update C_unconv
P_unconv(0)= gsl_ran_beta(s, N1_unconv+4, N-N1_unconv+4);
P_unconv(1)=1-P_unconv(0);

for (int k = 0 ; k < 2 ; k++) {
arma::mat lindividualitem_unconv(N,nitems);

for (int l=0; l<nitems; l++) {
arma::mat mutmp_unconv=I*alpha.row(l).t()+(I*alpha_conv.row(l).t())%(C_conv-1)+(I*alpha_unconv.row(l).t())*k;
for (int n = 0 ; n < N ; n++) {
lindividualitem_unconv(n,l)=pow(gsl_cdf_ugaussian_P(mutmp_unconv(n)), Y(n,l))* pow(gsl_cdf_ugaussian_P(-mutmp_unconv(n)), 1-Y(n,l));
}
}
pnum_unconv.col(k)=prod(lindividualitem_unconv,1)*P_unconv(k);
}
p_unconv=sum(pnum_unconv, 1);
probC_unconv=pnum_unconv.col(1)/p_unconv;

for (int n=0; n<N; n++){
C_unconv(n)=gsl_ran_binomial(s,probC_unconv(n), 1)+1;
}

N1_unconv=N-vm_convert(sum(C_unconv-1,0));


// Update item-specific parameters
for (int l=0; l<nitems; l++) {

mu.col(l) = I*alpha.row(l).t()+(I*alpha_conv.row(l).t())%(C_conv-1)+(I*alpha_unconv.row(l).t())%(C_unconv-1);

for (int n = 0 ; n < N ; n++) {
 if (Y(n, l) == 1) {
 ystar(n, l) = get1TN(mu(n, l),
 1,
 0,
 INFINITY
 ) ;
 }
 if (Y(n, l) == 0) {
 ystar(n, l) = get1TN(mu(n, l),
 1,
 -INFINITY,
 0
 ) ;
 }
 }

arma::mat llreffectalpha=I.t()*(ystar.col(l)-(I*alpha_conv.row(l).t())%(C_conv-1)-(I*alpha_unconv.row(l).t())%(C_unconv-1));


for (int c=0; c<ncountry; c++) {
double valpha=1/(tablecountry(c)+(1/sigma2_alphaintercept(l)));
double balpha=valpha * llreffectalpha(c);
alpha(l,c)= balpha+gsl_ran_gaussian_ziggurat(s,sqrt(valpha));

}


sigma2_alphaintercept(l)=1.0/gsl_ran_gamma(s, 1+ncountry/2, 1.0/(1.0+vm_mult(alpha.row(l))/2));


}

for (int l=0; l<(nitems-1); l++) {
for (int c=0; c<ncountry; c++) {
double valpha_conv=1/(vm_mult(I.col(c)%(C_conv-1))+(1/sigma2_alphaconv(l)));
double balpha_conv=valpha_conv * vm_convert((I.col(c)%(C_conv-1)).t()*(I.col(c)%(ystar.col(l)-I*alpha.row(l).t()-(I*alpha_unconv.row(l).t())%(C_unconv-1))));
alpha_conv(l,c)=truncn(0.0,TRUE, balpha_conv,  sqrt(valpha_conv));
}

sigma2_alphaconv(l)=1.0/gsl_ran_gamma(s, 1+ncountry/2, 1.0/(1.0+vm_mult(alpha_conv.row(l))/2));


}


for (int l=1; l<nitems; l++) {
for (int c=0; c<ncountry; c++){
double valpha_unconv=1/(vm_mult(I.col(c)%(C_unconv-1))+(1/sigma2_alphaunconv(l)));
double balpha_unconv=valpha_unconv * vm_convert((I.col(c)%(C_unconv-1)).t()*(I.col(c)%(ystar.col(l)-I*alpha.row(l).t()-(I*alpha_conv.row(l).t())%(C_conv-1))));
alpha_unconv(l,c)=truncn(0.0,TRUE, balpha_unconv,  sqrt(valpha_unconv));
}

sigma2_alphaunconv(l)=1.0/gsl_ran_gamma(s, 1+ncountry/2, 1.0/(1.0+vm_mult(alpha_unconv.row(l))/2));

}




 // TRACE
if ((iter+1)> burn & (iter+1)%thin==0) {
int j=((iter)-burn)/thin;
tracealpha.subcube(0, j, chain, nitems*ncountry-1, j, chain)=vectorise(alpha);
tracealphaconv.subcube(0, j, chain, nitems*ncountry-1, j, chain)=vectorise(alpha_conv);
tracealphaunconv.subcube(0, j, chain, nitems*ncountry-1, j, chain)=vectorise(alpha_unconv);
tracesigma2alpha.subcube(0,j,chain,nitems-1,j,chain)=sigma2_alphaintercept;
tracesigma2alphaconv.subcube(0,j,chain,nitems-1,j,chain)=sigma2_alphaconv;
tracesigma2alphaunconv.subcube(0,j,chain,nitems-1,j,chain)=sigma2_alphaunconv;
tracePconv.subcube(0, j, chain, 1, j, chain)=P_conv;
tracePunconv.subcube(0, j, chain, 1, j, chain)=P_unconv;
}



}

gsl_rng_free(s);
 
}

// Returns
Rcpp::List ret;

 ret["alpha"] = tracealpha;
 ret["alpha_conv"] = tracealphaconv;
 ret["alpha_unconv"] = tracealphaunconv;
ret["sigma2_alphaintercept"] = tracesigma2alpha;
ret["sigma2_alphaconv"] = tracesigma2alphaconv;
ret["sigma2_alphaunconv"] = tracesigma2alphaunconv;
ret["P_conv"] = tracePconv;
ret["P_unconv"] = tracePunconv;

 return(ret) ;
 

}

